import sys
import numpy as np
import torch
import torchvision.transforms as transforms
import torchvision.datasets as dsets
import torchvision.models as models
from scipy.special import comb
from torch.utils.data.sampler import SubsetRandomSampler
from utils.utils_models import linear_model, mlp_model
from cifar_models import resnet
import math
from scipy.io import arff
import pandas as pd
from sklearn.model_selection import train_test_split
from scipy.io import loadmat
import scipy.io as sio

def synth_uupr_train_dataset(args, prior, positive_train_data, negative_train_data):
    """Build the paired training set from a positive and a negative pool.

    Pairs are drawn in three groups: (positive, negative), (positive, positive)
    and (negative, negative). The first element of every pair goes to ``x1``
    and the second one to ``x2``. Group sizes come from ``args.n`` and
    ``prior``:

    * ``args.uci`` 0 or 1: the ``args.n`` pairs are split proportionally to
      ``prior*(1-prior)``, ``prior**2`` and ``(1-prior)**2`` (normalized).
    * ``args.uci`` 2: the (positive, positive) count is computed the same way.
      The (positive, negative) count then uses up all remaining positives, and
      the (negative, negative) group takes whatever is left of ``args.n``.

    When ``args.me == "small"``, pairs are drawn only from the first 10% of
    each (shuffled) pool.

    Args:
        args: namespace providing ``n`` (int), ``uci`` (0, 1 or 2) and ``me``.
        prior: class prior used to split the pair budget.
        positive_train_data: tensor of positive samples, first dim = samples.
        negative_train_data: tensor of negative samples, first dim = samples.

    Returns:
        ``(x1, x2, real_y1, real_y2, given_y1, given_y2)``. ``real_y*`` hold
        the true +1/-1 label of each element, ``given_y1`` is all +1 and
        ``given_y2`` is all -1.

    Raises:
        ValueError: if ``args.uci`` is unsupported, or if a pool has too few
            rows for the requested pair counts. Previously this produced
            ``x``/label tensors of different lengths that only failed later.
    """
    # Shuffle each pool so that the prefix slices below are random draws.
    positive_train_data = positive_train_data[torch.randperm(positive_train_data.shape[0])]
    negative_train_data = negative_train_data[torch.randperm(negative_train_data.shape[0])]

    pn_prior = prior * (1 - prior)
    pp_prior = prior * prior
    nn_prior = (1 - prior) * (1 - prior)
    total_prior = pn_prior + pp_prior + nn_prior
    n = args.n
    if args.uci == 0 or args.uci == 1:
        pn_number = math.floor(n * (pn_prior / total_prior))
        pp_number = math.floor(n * (pp_prior / total_prior))
    elif args.uci == 2:
        pp_number = math.floor(n * (pp_prior / total_prior))
        # Every positive not used by a (p, p) pair goes into a (p, n) pair.
        pn_number = positive_train_data.shape[0] - 2 * pp_number
    else:
        raise ValueError(f"unsupported args.uci={args.uci!r}; expected 0, 1 or 2")
    nn_number = n - pn_number - pp_number

    pos_needed = pn_number + 2 * pp_number
    neg_needed = pn_number + 2 * nn_number
    print("#pn pairs: ", pn_number, "#pp pairs", pp_number, "#nn pairs", nn_number)
    print("the number of actually used positive data is:", pos_needed)
    print("the number of actually used negative data is:", neg_needed)

    if args.me == "small":
        # "small" mode only draws from the first 10% of each shuffled pool.
        pos_pool = positive_train_data[:int(0.1 * positive_train_data.shape[0])]
        neg_pool = negative_train_data[:int(0.1 * negative_train_data.shape[0])]
    else:
        pos_pool = positive_train_data
        neg_pool = negative_train_data

    if min(pn_number, pp_number, nn_number) < 0:
        raise ValueError(
            f"negative pair count (pn={pn_number}, pp={pp_number}, nn={nn_number}); "
            f"check args.n={n} against the available data")
    if pos_needed > pos_pool.shape[0] or neg_needed > neg_pool.shape[0]:
        raise ValueError(
            f"not enough data: need {pos_needed} positives / {neg_needed} negatives, "
            f"have {pos_pool.shape[0]} / {neg_pool.shape[0]}")

    # Consecutive, non-overlapping slices: no sample is used twice.
    xpn_p = pos_pool[:pn_number]
    xpp_p1 = pos_pool[pn_number:pn_number + pp_number]
    xpp_p2 = pos_pool[pn_number + pp_number:pos_needed]
    xpn_n = neg_pool[:pn_number]
    xnn_n1 = neg_pool[pn_number:pn_number + nn_number]
    xnn_n2 = neg_pool[pn_number + nn_number:neg_needed]

    x1 = torch.cat((xpn_p, xpp_p1, xnn_n1), dim=0)
    x2 = torch.cat((xpn_n, xpp_p2, xnn_n2), dim=0)
    real_y1 = torch.cat((torch.ones(pn_number), torch.ones(pp_number), -torch.ones(nn_number)), dim=0)
    real_y2 = torch.cat((-torch.ones(pn_number), torch.ones(pp_number), -torch.ones(nn_number)), dim=0)
    given_y1 = torch.cat((torch.ones(pn_number), torch.ones(pp_number), torch.ones(nn_number)), dim=0)
    given_y2 = torch.cat((-torch.ones(pn_number), -torch.ones(pp_number), -torch.ones(nn_number)), dim=0)
    print(x1.shape, x2.shape, real_y1.shape, real_y2.shape, given_y1.shape, given_y2.shape)
    return x1, x2, real_y1, real_y2, given_y1, given_y2

def synth_test_dataset(prior, positive_test_data, negative_test_data):
    """Build a +1/-1 labelled test set whose positive fraction follows ``prior``.

    Requested counts per supported prior:

    * 0.2 -- every negative plus ``int(0.25 * #negatives)`` positives;
    * 0.5 -- ``min(#positives, #negatives)`` of each class;
    * 0.8 -- every positive plus ``int(0.25 * #positives)`` negatives;
    * any other value -- all positives and all negatives.

    Each pool is shuffled before the subset is taken, so the subset is a
    random draw rather than the first rows. The incoming pools may be ordered
    by original class. If a pool holds fewer rows than requested, all of its
    rows are used. Labels are built from the number of rows actually
    selected, so ``x`` and ``y`` always have the same length.

    Returns:
        ``(x, y)``: positives first, then negatives; ``y`` is +1 / -1.
    """
    num_p = positive_test_data.shape[0]
    num_n = negative_test_data.shape[0]
    if prior == 0.2:
        want_p, want_n = int(num_n * 0.25), num_n
    elif prior == 0.5:
        want_p = want_n = min(num_p, num_n)
    elif prior == 0.8:
        want_p, want_n = num_p, int(num_p * 0.25)
    else:
        want_p, want_n = num_p, num_n
    # Shuffle before truncating: a prefix of a class-ordered pool would bias
    # the test set toward only a few of the original classes.
    pos = positive_test_data[torch.randperm(num_p)][:want_p]
    neg = negative_test_data[torch.randperm(num_n)][:want_n]
    x = torch.cat((pos, neg), dim=0)
    # Count the rows actually sliced, not the requested amount.
    y = torch.cat((torch.ones(pos.shape[0]), -torch.ones(neg.shape[0])), dim=0)
    return x, y

def generate_uupr_loaders(xp, xn, given_yp, given_yn, real_yp, real_yn, xt, yt, batch_size):
    """Wrap paired training data and test data in shuffling DataLoaders.

    Returns, in this order:
      1. a loader over ``(xp, given_yp)``;
      2. a loader over ``(xn, given_yn)``;
      3. a loader over ``xp`` stacked on ``xn`` with their given labels;
      4. the same stacked inputs with their real labels;
      5. a loader over the test pair ``(xt, yt)``.

    Every loader uses ``batch_size``, ``shuffle=True`` and ``num_workers=8``.
    """
    data = torch.utils.data

    def _shuffled(dataset):
        # All loaders here share identical settings.
        return data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=8)

    stacked_x = torch.cat((xp, xn), dim=0)
    stacked_given = torch.cat((given_yp, given_yn), dim=0)
    stacked_real = torch.cat((real_yp, real_yn), dim=0)
    print(stacked_x.shape, stacked_given.shape, stacked_real.shape)

    datasets = (
        data.TensorDataset(xp, given_yp),
        data.TensorDataset(xn, given_yn),
        data.TensorDataset(stacked_x, stacked_given),
        data.TensorDataset(stacked_x, stacked_real),
        data.TensorDataset(xt, yt),
    )
    pos_loader, neg_loader, given_loader, real_loader, test_loader = (_shuffled(ds) for ds in datasets)
    return pos_loader, neg_loader, given_loader, real_loader, test_loader
    
def generate_small_loaders(xp, xn, given_yp, given_yn, real_yp, real_yn, xt, yt, batch_size):
    """Build a training loader mixing real and given labels, plus a test loader.

    ``xp`` and ``xn`` are stacked. The first ``int(0.1 * N)`` stacked rows
    keep their real labels and the remaining rows get their given labels.
    Both loaders use ``batch_size``, ``shuffle=True`` and ``num_workers=8``.

    NOTE(review): the real-labelled 10% is the *prefix* of the stacked rows,
    not a random 10%. If ``xp`` is ordered by pair type, these rows all come
    from one group. Confirm this is intended.

    Returns:
        ``(small_train_loader, test_loader)``.
    """
    inputs = torch.cat((xp, xn), dim=0)
    given = torch.cat((given_yp, given_yn), dim=0)
    real = torch.cat((real_yp, real_yn), dim=0)
    cut = int(real.shape[0] * 0.1)
    mixed_labels = torch.cat((real[:cut], given[cut:]), dim=0)

    loaders = [
        torch.utils.data.DataLoader(
            dataset=torch.utils.data.TensorDataset(features, targets),
            batch_size=batch_size,
            shuffle=True,
            num_workers=8,
        )
        for features, targets in ((inputs, mixed_labels), (xt, yt))
    ]
    return loaders[0], loaders[1]
    
def generate_original_binary_loaders(positive_train_data, negative_train_data, positive_test_data, negative_test_data, batch_size):
    """Build ordinary supervised binary loaders with +1 / -1 labels.

    In each split the positive rows come first (label +1), followed by the
    negative rows (label -1). Both loaders use ``batch_size``,
    ``shuffle=True`` and ``num_workers=8``.

    Returns:
        ``(binary_train_loader, binary_test_loader)``.
    """
    def _labelled(pos, neg):
        # Stack the pools and attach +1 to positives and -1 to negatives.
        features = torch.cat((pos, neg), dim=0)
        targets = torch.cat((torch.ones(pos.shape[0]), -torch.ones(neg.shape[0])), dim=0)
        return torch.utils.data.TensorDataset(features, targets)

    splits = (
        _labelled(positive_train_data, negative_train_data),
        _labelled(positive_test_data, negative_test_data),
    )
    return tuple(
        torch.utils.data.DataLoader(dataset=split, batch_size=batch_size, shuffle=True, num_workers=8)
        for split in splits
    )

def _select_class_indices(index, labels, classes):
    """Return entries of ``index`` whose label is in ``classes``, grouped class by class in the given order."""
    return torch.cat([index[labels == c] for c in classes], dim=0)


def convert_to_binary_data(dataname, train_data, train_labels, test_data, test_labels):
    """Split a multi-class dataset into positive and negative pools.

    Class-to-binary mapping:

    * ``'cifar10'``: classes 2-7 are positive; 0, 1, 8, 9 are negative.
    * ``'mnist'``, ``'fashion'``, ``'kmnist'``, ``'svhn'``: even classes are
      positive and odd classes negative.
    * anything else: rows labelled 1 are positive, rows labelled -1 negative.

    Within each pool, rows are grouped by original class in the order listed
    above (not interleaved by original position).

    Returns:
        ``(positive_train_data, negative_train_data, positive_test_data,
        negative_test_data)``, each converted with ``.float()``.
    """
    if dataname == 'cifar10':
        positive_classes, negative_classes = (2, 3, 4, 5, 6, 7), (0, 1, 8, 9)
    elif dataname in ('mnist', 'fashion', 'kmnist', 'svhn'):
        positive_classes, negative_classes = (0, 2, 4, 6, 8), (1, 3, 5, 7, 9)
    else:
        # Labels are already binary.
        positive_classes, negative_classes = (1,), (-1,)

    train_index = torch.arange(train_labels.shape[0])
    test_index = torch.arange(test_labels.shape[0])
    positive_train_index = _select_class_indices(train_index, train_labels, positive_classes)
    negative_train_index = _select_class_indices(train_index, train_labels, negative_classes)
    positive_test_index = _select_class_indices(test_index, test_labels, positive_classes)
    negative_test_index = _select_class_indices(test_index, test_labels, negative_classes)

    positive_train_data = train_data[positive_train_index].float()
    negative_train_data = train_data[negative_train_index].float()
    positive_test_data = test_data[positive_test_index].float()
    negative_test_data = test_data[negative_test_index].float()
    return positive_train_data, negative_train_data, positive_test_data, negative_test_data

def prepare_mnist_data():
    """Load MNIST as raw tensors, downloading the training split if needed.

    The raw ``.data`` / ``.targets`` attributes are read directly, so the
    ``ToTensor`` transform given to the datasets does not touch the returned
    tensors.

    Returns:
        ``(train_data, train_labels, test_data, test_labels, dim)``. The images
        are reshaped to ``(-1, 1, 28, 28)`` and ``dim`` is ``28*28``.
    """
    root = './data/mnist'
    splits = [
        dsets.MNIST(root=root, train=True, transform=transforms.ToTensor(), download=True),
        dsets.MNIST(root=root, train=False, transform=transforms.ToTensor()),
    ]
    (train_x, train_y), (test_x, test_y) = [
        (split.data.reshape(-1, 1, 28, 28), split.targets) for split in splits
    ]
    return train_x, train_y, test_x, test_y, 28 * 28

def prepare_kmnist_data():
    """Load KMNIST as raw tensors, downloading the training split if needed.

    The raw ``.data`` / ``.targets`` attributes are read directly, so the
    ``ToTensor`` transform given to the datasets does not touch the returned
    tensors.

    Returns:
        ``(train_data, train_labels, test_data, test_labels, dim)``. The images
        are reshaped to ``(-1, 1, 28, 28)`` and ``dim`` is ``28*28``.
    """
    train_split = dsets.KMNIST(root='./data/KMNIST', train=True, transform=transforms.ToTensor(), download=True)
    test_split = dsets.KMNIST(root='./data/KMNIST', train=False, transform=transforms.ToTensor())
    image_shape = (-1, 1, 28, 28)
    return (
        train_split.data.reshape(image_shape),
        train_split.targets,
        test_split.data.reshape(image_shape),
        test_split.targets,
        28 * 28,
    )

def prepare_fashion_data():
    """Load Fashion-MNIST as raw tensors, downloading the training split if needed.

    The raw ``.data`` / ``.targets`` attributes are read directly, so the
    ``ToTensor`` transform given to the datasets does not touch the returned
    tensors.

    Returns:
        ``(train_data, train_labels, test_data, test_labels, dim)``. The images
        are reshaped to ``(-1, 1, 28, 28)`` and ``dim`` is ``28*28``.
    """
    root = './data/FashionMnist'
    train_split = dsets.FashionMNIST(root=root, train=True, transform=transforms.ToTensor(), download=True)
    test_split = dsets.FashionMNIST(root=root, train=False, transform=transforms.ToTensor())

    def _images(split):
        # Add the single channel axis expected by the models.
        return split.data.reshape(-1, 1, 28, 28)

    train_images = _images(train_split)
    train_targets = train_split.targets
    test_images = _images(test_split)
    test_targets = test_split.targets
    flat_dim = 28 * 28
    return train_images, train_targets, test_images, test_targets, flat_dim

def prepare_cifar10_data():
    """Load CIFAR-10 as raw tensors, downloading the training split if needed.

    The raw ``.data`` arrays are read directly. The ``train_transform`` /
    ``test_transform`` pipelines passed to the datasets are therefore NOT
    applied to the returned tensors: no ``ToTensor`` scaling and no
    normalization.

    Returns:
        ``(train_data, train_labels, test_data, test_labels, dim)``: images as
        float tensors of shape ``(N, C, H, W)``, labels as float tensors, and
        ``dim = C * H * W``.
    """
    train_transform = transforms.Compose(
        [transforms.ToTensor(), # transforms.RandomHorizontalFlip(), transforms.RandomCrop(32,4),
         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))])
    test_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))])
    ordinary_train_dataset = dsets.CIFAR10(root='./data', train=True, transform=train_transform, download=True)
    test_dataset = dsets.CIFAR10(root='./data', train=False, transform=test_transform)

    def _to_nchw(images):
        # ``.data`` is a NumPy array laid out (N, H, W, C) with C == 3.
        # The axes must be moved with permute. A plain reshape to (N, C, H, W)
        # keeps the interleaved RGB memory order and scrambles every image.
        return torch.from_numpy(images).permute(0, 3, 1, 2).contiguous().float()

    train_data = _to_nchw(ordinary_train_dataset.data)
    train_labels = torch.tensor(ordinary_train_dataset.targets).float()  # targets is a list
    test_data = _to_nchw(test_dataset.data)
    test_labels = torch.tensor(test_dataset.targets).float()
    _, channels, height, width = train_data.shape
    dim = channels * height * width  # 3*32*32 (28*28 is the MNIST size)
    return train_data, train_labels, test_data, test_labels, dim

def prepare_uci1_data(args):
    """Load a UCI dataset, binarize it by label parity and split it 80/20.

    * ``usps`` and ``pendigits`` are read from ``data/uci_<ds>.mat``. The
      ``label`` matrix is decoded with ``argmax`` and a decoded 10 is mapped
      to 0.
    * Any other ``args.ds`` is read from ``uci_data_arff/<ds>.arff``, with the
      last column as the class label.

    Classes 0, 2, 4, 6, 8 form the positive pool and 1, 3, 5, 7, 9 the
    negative pool; other labels are dropped. Each pool is shuffled, and its
    first ``int(0.8 * size)`` rows become training data while the rest become
    test data.

    Returns:
        ``(positive_train_data, negative_train_data, positive_test_data,
        negative_test_data, num_train, num_test, dim)``, where ``dim`` is the
        number of feature columns.
    """
    if args.ds in ('usps', 'pendigits'):
        mat = sio.loadmat("data/uci_" + args.ds + ".mat")
        features = torch.from_numpy(mat['data']).float()
        labels = torch.from_numpy(mat['label']).float().argmax(dim=1)
        labels[labels == 10] = 0
    else:
        arff_records = arff.loadarff('uci_data_arff/' + args.ds + '.arff')[0]
        matrix = pd.DataFrame(arff_records).values.astype('float')
        features = torch.from_numpy(matrix[:, :-1]).float()
        labels = torch.from_numpy(matrix[:, -1].astype('long')).float()

    row_ids = torch.arange(labels.shape[0])
    pos_rows = torch.cat([row_ids[labels == c] for c in (0, 2, 4, 6, 8)], dim=0)
    neg_rows = torch.cat([row_ids[labels == c] for c in (1, 3, 5, 7, 9)], dim=0)
    pos = features[pos_rows, :]
    neg = features[neg_rows, :]
    n_pos, n_neg = pos.shape[0], neg.shape[0]

    # Shuffle each pool (positives first, then negatives) before splitting.
    pos = pos[torch.randperm(n_pos)]
    neg = neg[torch.randperm(n_neg)]
    cut_p = int(n_pos * 0.8)
    cut_n = int(n_neg * 0.8)
    train_pos, test_pos = pos[:cut_p, :], pos[cut_p:, :]
    train_neg, test_neg = neg[:cut_n, :], neg[cut_n:, :]

    num_train = cut_p + cut_n
    num_test = (n_pos + n_neg) - num_train
    return train_pos, train_neg, test_pos, test_neg, num_train, num_test, train_pos.shape[1]


def get_model(args, dim, device):
    """Instantiate the single-output model selected by ``args`` and move it to ``device``.

    * ``args.ds`` in {'cifar10', 'svhn', 'cifar100'}: ``args.mo == 'resnet'``
      builds a depth-32 ResNet. Any other ``args.mo`` uses ``densenet``.
    * Any other dataset: ``args.mo`` must be ``'linear'`` or ``'mlp'``
      (hidden width 300). Both take ``dim`` input features.

    Raises:
        ValueError: for an unknown ``args.mo`` on the non-image branch. This
            previously surfaced as an ``UnboundLocalError`` on ``model``.
    """
    if args.ds in ('cifar10', 'svhn', 'cifar100'):
        if args.mo == 'resnet':
            return resnet(depth=32, num_classes=1).to(device)
        # NOTE(review): ``densenet`` is not among this module's visible
        # imports; confirm it is provided elsewhere before relying on this path.
        return densenet(num_classes=1).to(device)
    if args.mo == 'linear':
        return linear_model(input_dim=dim, output_dim=1).to(device)
    if args.mo == 'mlp':
        return mlp_model(input_dim=dim, hidden_dim=300, output_dim=1).to(device)
    raise ValueError(f"unsupported model args.mo={args.mo!r} for dataset {args.ds!r}; expected 'linear' or 'mlp'")

def generate_uupr_data(args):
    """Load a dataset, binarize it and synthesize paired training / test data.

    * ``args.uci == 0``: image benchmarks selected by ``args.ds`` ('mnist',
      'kmnist', 'fashion', 'cifar10', 'cifar100'), binarized with
      ``convert_to_binary_data``.
    * ``args.uci == 1``: UCI data loaded and split by ``prepare_uci1_data``.

    Both paths then call ``synth_uupr_train_dataset`` and
    ``synth_test_dataset`` with ``args.prior``.

    Returns:
        ``(x1, x2, real_y1, real_y2, given_y1, given_y2, xt, yt, dim)``.

    Raises:
        ValueError: for an unsupported ``args.uci`` or ``args.ds``. These
            previously surfaced as ``UnboundLocalError``.
    """
    if args.uci == 0:
        if args.ds == 'mnist':
            loaded = prepare_mnist_data()
        elif args.ds == 'kmnist':
            loaded = prepare_kmnist_data()
        elif args.ds == 'fashion':
            loaded = prepare_fashion_data()
        elif args.ds == 'cifar10':
            loaded = prepare_cifar10_data()
        elif args.ds == 'cifar100':
            loaded = prepare_cifar100_data()
        else:
            raise ValueError(f"unsupported dataset args.ds={args.ds!r} for args.uci=0")
        train_data, train_labels, test_data, test_labels, dim = loaded
        print("#original train:", train_data.shape, "#original test", test_data.shape)
        positive_train_data, negative_train_data, positive_test_data, negative_test_data = convert_to_binary_data(args.ds, train_data, train_labels, test_data, test_labels)
        print("#all train positive:", positive_train_data.shape, "#all train negative:", negative_train_data.shape, "#all test positive:", positive_test_data.shape, "#all test negative:", negative_test_data.shape)
    elif args.uci == 1:
        positive_train_data, negative_train_data, positive_test_data, negative_test_data, num_train, num_test, dim = prepare_uci1_data(args)
        print("#original train:", num_train, "#original test", num_test)
        print("#all train positive:", positive_train_data.shape, "#all train negative:", negative_train_data.shape)
    else:
        # synth_uupr_train_dataset also knows uci == 2, but no loader here produces data for it.
        raise ValueError(f"unsupported args.uci={args.uci!r}; expected 0 or 1")

    # Shared tail: synthesize pairs from the binary pools.
    x1, x2, real_y1, real_y2, given_y1, given_y2 = synth_uupr_train_dataset(args, args.prior, positive_train_data, negative_train_data)
    xt, yt = synth_test_dataset(args.prior, positive_test_data, negative_test_data)
    print("#test positive:", (yt==1).sum(), "#test negative:", (yt==-1).sum())
    return x1, x2, real_y1, real_y2, given_y1, given_y2, xt, yt, dim